#!/bin/bash

# Environment opt-ins. They are exported (not just assigned) so that every
# process launched further down in this script inherits them.
HF_DATASETS_TRUST_REMOTE_CODE=true
HF_ALLOW_CODE_EVAL=1
export HF_ALLOW_CODE_EVAL HF_DATASETS_TRUST_REMOTE_CODE

# Model checkpoint and the task to evaluate it on.
MODEL_PATH='GSAI-ML/LLaDA-8B-Instruct'
MODEL_NAME='llada_instruct'   # NOTE(review): not referenced by the commands visible in this script
TASK='tinyGSM8k'

# NOTE(review): LOG_DIR is not referenced by the commands visible in this script.
LOG_DIR='logs/tinyGSM8k/tinyGSM8k_baseline_512'

# Generation settings forwarded to the model via --model_args.
BLOCK_LENGTH=32
GEN_LENGTH=512
# Integer (truncating) division: with the values above, 512 / 32 = 16.
STEPS=$(( GEN_LENGTH / BLOCK_LENGTH ))

# Run 1 - no cache: --model_args sets neither use_cache nor dual_cache.
# Launched as a single process (--num_processes=1) with 5 few-shot examples.
# Every expansion is quoted so that each value reaches the launcher as exactly
# one argument, even if a setting is later edited to contain spaces or glob
# characters.
accelerate launch --num_processes=1 eval_llada_baseline.py \
    --tasks "${TASK}" --num_fewshot 5 \
    --confirm_run_unsafe_code --model llada_dist \
    --model_args "model_path=${MODEL_PATH},gen_length=${GEN_LENGTH},steps=${STEPS},block_length=${BLOCK_LENGTH},show_speed=True"

# Run 2 - prefix cache: same settings as the no-cache run, plus use_cache=true
# in --model_args.
# Every expansion is quoted so that each value reaches the launcher as exactly
# one argument, even if a setting is later edited to contain spaces or glob
# characters.
accelerate launch --num_processes=1 eval_llada_baseline.py \
    --tasks "${TASK}" --num_fewshot 5 \
    --confirm_run_unsafe_code --model llada_dist \
    --model_args "model_path=${MODEL_PATH},gen_length=${GEN_LENGTH},steps=${STEPS},block_length=${BLOCK_LENGTH},show_speed=True,use_cache=true"

# Run 3 - dual cache: passes use_cache=true and dual_cache=true.
# Unlike the two runs above, this one also sets threshold=0.9, so it differs
# from the prefix-cache run in more than just the cache mode.
# Every expansion is quoted so that each value reaches the launcher as exactly
# one argument, even if a setting is later edited to contain spaces or glob
# characters.
accelerate launch --num_processes=1 eval_llada_baseline.py \
    --tasks "${TASK}" --num_fewshot 5 \
    --confirm_run_unsafe_code --model llada_dist \
    --model_args "model_path=${MODEL_PATH},gen_length=${GEN_LENGTH},steps=${STEPS},block_length=${BLOCK_LENGTH},threshold=0.9,show_speed=True,use_cache=true,dual_cache=true"